# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""the module is used to process images."""

import numpy as np

import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.py_transforms as PY

from mindvision.common.utils.class_factory import ClassFactory, ModuleType


@ClassFactory.register(ModuleType.PIPELINE)
class Decode:
    """Wrap the mindspore Decode operation.

    Args:
        decode_mode (str): ``'C'`` selects ``c_transforms.Decode``; any other
            value selects the Python ``py_transforms.Decode``.
    """

    def __init__(self, decode_mode='C'):
        self.decode_mode = decode_mode
        if decode_mode == 'C':
            self.decode = C.Decode()
        else:
            # ``Decode`` is provided by ``mindspore.dataset.vision.py_transforms``;
            # the module-level ``PY`` alias points at the generic
            # ``mindspore.dataset.transforms.py_transforms`` (Compose, RandomApply, ...),
            # which does not provide ``Decode``.
            from mindspore.dataset.vision import py_transforms as py_vision
            self.decode = py_vision.Decode()

    def __call__(self, results):
        """Decode ``results['image']`` in place and record its (height, width).

        Args:
            results (dict): must contain the encoded image under ``'image'``.

        Returns:
            dict: ``results`` with the decoded ``'image'`` and an
            ``'image_shape'`` tuple in (height, width) order.
        """
        img_dec = self.decode(results['image'])
        results['image'] = img_dec

        # shape format as HW
        if self.decode_mode == 'C':
            # The C decoder returns an array; its first two axes are taken as (H, W).
            results['image_shape'] = img_dec.shape[:2]
        else:
            # The Python decoder returns a PIL image whose ``size`` is (width, height).
            results['image_shape'] = (img_dec.size[1], img_dec.size[0])
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class HWC2CHW:
    """Reorder ``results['image']`` from height-width-channel to channel-height-width.

    The actual work is delegated to mindspore's C-level ``HWC2CHW`` transform.
    """

    def __init__(self):
        self.hwc2chw = C.HWC2CHW()

    def __call__(self, results):
        """Replace the image with its channel-first version and return ``results``."""
        results['image'] = self.hwc2chw(results['image'])
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class ImgRgbToBgr:
    """Convert an RGB image to BGR by swapping channels 0 and 2."""

    def __init__(self):
        """No configuration is required."""

    def __call__(self, results):
        """Store a channel-swapped copy of ``results['image']`` back into ``results``.

        Channels other than 0 and 2 are copied unchanged; the input array
        itself is not modified.
        """
        src = results['image']
        swapped = src.copy()
        # Fancy indexing on the right-hand side materialises a copy, so the
        # two channels are exchanged without aliasing.
        swapped[:, :, [0, 2]] = src[:, :, [2, 0]]
        results['image'] = swapped
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class Transpose:
    """Permute the axes of ``results['image']``.

    Args:
        perm (tuple): axis order passed to ``transpose``; the default
            ``(2, 0, 1)`` moves the last axis to the front.
    """

    def __init__(self, perm=(2, 0, 1)):
        self.perm = perm

    def __call__(self, results):
        """Transpose the image and store an independent copy back into ``results``."""
        permuted = results.get("image").transpose(self.perm)
        results['image'] = permuted.copy()
        return results


@ClassFactory.register(ModuleType.PIPELINE)
class EvalFormat:
    """Package an ``(image, image_id)`` sample into a result dict for evaluation."""

    def __init__(self):
        """No configuration is required."""

    def __call__(self, data_tuple):
        """Build the evaluation result dict.

        Args:
            data_tuple: sequence whose first item is the image and second the image id.

        Returns:
            dict: ``image``, ``image_id`` and ``image_shape`` (int32 array of the
            image's first two dimensions).
        """
        image, image_id = data_tuple[0], data_tuple[1]
        return {
            'image': image,
            "image_id": image_id,
            "image_shape": np.array(image.shape[:2], np.int32),
        }


@ClassFactory.register(ModuleType.PIPELINE)
class Format:
    """Format the input data into a result dict.

    Args:
        is_infer (bool): when True, the second element of the input tuple is
            stored as ``image_id``; otherwise it is treated as annotations and
            split by :meth:`pad_gt`.
        pad_max_number (int, optional): pad config. When set, boxes, labels and
            validity flags are padded along the first axis to this length.
    """

    def __init__(self,
                 is_infer=False,
                 pad_max_number=None):
        self.is_infer = is_infer
        self.pad_max_number = pad_max_number

    def __call__(self, data_tuple):
        """Build the result dict from an ``(image, annotations_or_id)`` tuple.

        Returns:
            dict: always ``image`` and ``image_shape`` (int32 array of the
            image's first two dimensions), plus ``image_id`` in inference mode
            or ``bboxes``/``labels``/``valid_num`` otherwise.
        """
        image = data_tuple[0]
        img_infos = data_tuple[1]  # annotations (box, label, iscrowd) or image id
        image_shape = np.array(image.shape[:2], np.int32)

        results = {"image": image,
                   "image_shape": image_shape}

        if not self.is_infer:
            results = self.pad_gt(img_infos, results)
        else:
            results["image_id"] = img_infos

        return results

    def pad_gt(self, annotations, results):
        """Split annotations into boxes, labels and validity flags, padding if configured.

        Args:
            annotations: 2-D array; columns 0-3 are the box, column 4 the
                label and column 5 the iscrowd flag.
            results (dict): dict to fill in place.

        Returns:
            dict: ``results`` with ``bboxes``, ``labels`` and ``valid_num`` set.

        Raises:
            ValueError: if ``pad_max_number`` is smaller than the number of
                annotation rows.
        """
        gt_box = annotations[:, :4]
        gt_label = annotations[:, 4]
        gt_iscrowd = annotations[:, 5]

        if self.pad_max_number is not None:
            pad_len = self.pad_max_number - annotations.shape[0]
            if pad_len < 0:
                raise ValueError(
                    "pad_max_number ({}) is smaller than the number of "
                    "annotations ({})".format(self.pad_max_number, annotations.shape[0]))
            gt_box = np.pad(gt_box, ((0, pad_len), (0, 0)),
                            mode="constant", constant_values=0)
            gt_label = np.pad(gt_label, ((0, pad_len),),
                              mode="constant", constant_values=-1)
            # Padded rows are marked as crowd (1) so they end up invalid (0) below.
            gt_iscrowd = np.pad(gt_iscrowd, ((0, pad_len),),
                                mode="constant", constant_values=1)

        # Invert iscrowd: non-crowd rows -> 1, crowd/padded rows -> 0.
        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` is the same dtype.
        valid = np.logical_not(gt_iscrowd.astype(bool)).astype(np.int32)

        results["bboxes"] = gt_box
        results["labels"] = gt_label
        results["valid_num"] = valid

        return results


@ClassFactory.register(ModuleType.PIPELINE)
class Collect:
    """Collect output image data and convert the result dict to a tuple.

    Args:
        output_orders (list): keys of ``results`` to emit, in order. A key
            missing from ``results`` yields an empty list placeholder.
        output_type_dict (dict, optional): maps a key to a dtype name from
            ``_np_type_dict``; matching entries are cast before collecting.
    """
    # ``np.bool`` was a deprecated alias that NumPy 1.24 removed; referencing it
    # here would fail at class-definition time, so use ``np.bool_``.
    _np_type_dict = {'bool': np.bool_,
                     'int8': np.int8,
                     'int16': np.int16,
                     'int32': np.int32,
                     'int64': np.int64,
                     'uint8': np.uint8,
                     'uint16': np.uint16,
                     'uint32': np.uint32,
                     'uint64': np.uint64,
                     'float16': np.float16,
                     'float32': np.float32,
                     'float64': np.float64}

    def __init__(self, output_orders, output_type_dict=None):
        """Constructor for Collect."""
        self.output_type_dict = output_type_dict
        self.output_orders = output_orders

    def np_type_cast(self, results):
        """Cast the entries named in ``output_type_dict`` in place and return ``results``."""
        if self.output_type_dict is not None:
            for key, type_name in self.output_type_dict.items():
                if key in results:
                    results[key] = results[key].astype(self._np_type_dict[type_name])
        return results

    def __call__(self, results):
        """Return a tuple of ``results`` values ordered by ``output_orders``.

        Keys absent from ``results`` produce ``[]`` so the tuple length always
        matches ``output_orders``.
        """
        results = self.np_type_cast(results)
        return tuple(results.get(key, []) for key in self.output_orders)
